# from spacy_llm.util import assemble
from tqdm import tqdm
import openai
from openai import OpenAI
import os
import time
import json

from utils.util import write_json, read_json, read_txt, is_json
from utils.token_count_decorator import token_count_decorator

class Recognize:
    """Annotate extracted sentences with entity labels and scientific quantities via OpenAI models.

    Two families of methods are provided:

    * ``recognize_labels`` / ``recognize_quantities`` call the chat-completions
      endpoint synchronously, ``SENTENCES_PER_REQUEST`` sentences per request.
    * ``recognize_labels_2`` / ``recognize_quantities_2`` / ``recognize_whole``
      use the Batch API with one request per sentence and ``concurrent``
      sentence groups per batch job.

    The input data is a list of sentence groups; each group is a list of dicts
    carrying at least ``"sentence"`` and ``"opcode"`` (``recognize_whole`` also
    reads ``"operation"``).
    """

    # Placeholder in every prompt template that is replaced by the sentence text(s).
    PROMPT_PLACEHOLDER = "---SENTENCES---"
    SYSTEM_PROMPT = "You are a natural language processing model designed for performing NLP tasks."
    MODEL = "gpt-3.5-turbo"
    # max_tokens for each Batch API request.
    BATCH_MAX_TOKENS = 1000
    # Number of sentences joined into one synchronous chat request.
    SENTENCES_PER_REQUEST = 8
    # Seconds between polls of a running batch job.
    BATCH_POLL_INTERVAL = 3
    # Batch statuses on which polling gives up, with the message printed for each.
    _TERMINAL_BATCH_STATUSES = {
        "failed": "Batch failed",
        "expired": "Batch expired",
        "cancelled": "Batch cancelled",
        "cancelling": "Batch cancelling",
    }

    def __init__(self, extracted_data_path, recognized_store_path, concurrent=9):
        """Load the extracted sentences and the prompt templates.

        Args:
            extracted_data_path: JSON file with the extracted sentence groups.
            recognized_store_path: JSON file the batch-based methods write to
                (``recognize_quantities_2`` also reads its input from it).
            concurrent: number of sentence groups submitted per Batch API job.
        """
        self.extracted_data = read_json(extracted_data_path)
        self.recognized_store_path = recognized_store_path
        self.labels_extraction_prompt = read_txt("dsl_design/data/prompt/labels_extraction.txt")
        self.quantities_extraction_prompt = read_txt("dsl_design/data/prompt/scientific_quantities_extraction.txt")
        self.entities_extraction_prompt = read_txt("dsl_design/data/prompt/entities_extraction.txt")
        self.batch_input_path = "dsl_design/data/temp_batch/recognize_batch_input.jsonl"
        self.batch_output_path = "dsl_design/data/temp_batch/recognize_batch_output.jsonl"
        self.concurrent = concurrent

    # ------------------------------------------------------------------
    # Synchronous (chat-completions) recognition
    # ------------------------------------------------------------------

    def recognize_labels(self, recognized_store_path):
        """Recognize entity labels synchronously and write the result to ``recognized_store_path``.

        Each sentence becomes ``{"sentence", "opcode", "entities"}`` where
        ``entities`` is a list of ``{"label", "text"}`` parsed from
        ``word | label | ...`` lines of the model reply.
        """
        recognized_data = self.__recognize_sequential(
            self.extracted_data,
            self.labels_extraction_prompt,
            lambda s: {"sentence": s["sentence"], "opcode": s["opcode"], "entities": []},
            parts_num=3,
            target_key="entities",
            value_key="label",
        )
        write_json(recognized_store_path, recognized_data)

    def recognize_quantities(self, recognized_store_path):
        """Add scientific quantities to the label-recognized data stored in ``recognized_store_path``.

        Reads the file, adds a ``quantities`` list of ``{"quantity", "text"}``
        (parsed from ``word | quantity`` lines) to every sentence, and writes
        the result back to the same file.
        """
        recognized_data_raw = read_json(recognized_store_path)
        recognized_data = self.__recognize_sequential(
            recognized_data_raw,
            self.quantities_extraction_prompt,
            lambda s: {
                "sentence": s["sentence"],
                "opcode": s["opcode"],
                "entities": s["entities"],
                "quantities": [],
            },
            parts_num=2,
            target_key="quantities",
            value_key="quantity",
        )
        write_json(recognized_store_path, recognized_data)

    def __recognize_sequential(self, data, prompt_template, make_sentence, parts_num, target_key, value_key):
        """Run ``__recognize_chunk`` over every group of ``data`` in chunks of ``SENTENCES_PER_REQUEST``.

        Returns a new list of groups whose sentences were built by ``make_sentence``.
        """
        recognized_data = []
        chunk_size = self.SENTENCES_PER_REQUEST
        for extracted in tqdm(data):
            recognized = []
            for i in range(0, len(extracted), chunk_size):
                chunk = [make_sentence(sentence) for sentence in extracted[i:i + chunk_size]]
                self.__recognize_chunk(chunk, prompt_template, parts_num, target_key, value_key)
                recognized.extend(chunk)
            recognized_data.append(recognized)
        return recognized_data

    def __recognize_chunk(self, sentences, prompt_template, parts_num, target_key, value_key):
        """Query the model for one chunk of sentences and attach the parsed items in place.

        Every reply line with exactly ``parts_num`` ``|``-separated fields yields
        ``{value_key: field1, "text": field0}``, appended to ``target_key`` of the
        first sentence (searching forward from the last match) that contains field0.
        """
        texts = [sentence["sentence"] for sentence in sentences]
        reply = self.__chatgpt_function(
            prompt_template.replace(self.PROMPT_PLACEHOLDER, "\n".join(texts))
        )
        # Throttle between synchronous requests.
        time.sleep(8)
        search_from = 0
        for line in (reply or "").split("\n"):
            parts = line.split("|")
            if len(parts) != parts_num:
                continue
            word = parts[0].strip()
            value = parts[1].strip()
            # Items are assumed to come back in sentence order, so search forward only.
            sentence_index = self.__find_sentence_index(texts, word, search_from)
            if sentence_index == -1:
                print("Can't find sentence index. Previous index: ", search_from)
                print(word, value)
                continue
            # __find_sentence_index returns a 0-based index; use it directly.
            sentences[sentence_index][target_key].append({value_key: value, "text": word})
            search_from = sentence_index

    # ------------------------------------------------------------------
    # Batch API recognition
    # ------------------------------------------------------------------

    def recognize_labels_2(self):
        """Recognize entity labels with the Batch API and write the result to ``recognized_store_path``.

        Each sentence becomes ``{"sentence", "opcode", "entities"}``; ``entities``
        is a list of ``{"word", "label"}`` items, or ``[]`` when no result arrived.
        """
        recognized_data = self.__recognize_in_batches(
            self.extracted_data,
            self.labels_extraction_prompt,
            lambda s: {"sentence": s["sentence"], "opcode": s["opcode"], "entities": []},
            "entities",
            lambda batch_id: self.__get_batch_result(batch_id, 3),
            pause_seconds=5,
        )
        print("Labels recognized")
        write_json(self.recognized_store_path, recognized_data)

    def recognize_quantities_2(self):
        """Add quantities (Batch API) to the label-recognized data in ``recognized_store_path``.

        Each sentence gains a ``quantities`` list of ``{"word", "label"}`` items
        (``[]`` when no result arrived); the file is overwritten with the result.
        """
        label_recognized_data = read_json(self.recognized_store_path)
        recognized_data = self.__recognize_in_batches(
            label_recognized_data,
            self.quantities_extraction_prompt,
            lambda s: {
                "sentence": s["sentence"],
                "opcode": s["opcode"],
                "entities": s["entities"],
                "quantities": [],
            },
            "quantities",
            lambda batch_id: self.__get_batch_result(batch_id, 2),
            pause_seconds=10,
        )
        print("Quantities recognized")
        write_json(self.recognized_store_path, recognized_data)

    def recognize_whole(self):
        """Recognize entities and quantities in one pass with the Batch API.

        Each sentence becomes ``{"sentence", "opcode", "operation", "recognized"}``
        where ``recognized`` is the model reply parsed as JSON (``{}`` when the
        reply is not JSON or no result arrived). Written to ``recognized_store_path``.
        """
        recognized_data = self.__recognize_in_batches(
            self.extracted_data,
            self.entities_extraction_prompt,
            lambda s: {
                "sentence": s["sentence"],
                "opcode": s["opcode"],
                "operation": s["operation"],
                "recognized": {},
            },
            "recognized",
            self.__get_batch_result_2,
            pause_seconds=5,
            verbose=True,
        )
        print("Recognized")
        write_json(self.recognized_store_path, recognized_data)

    def __recognize_in_batches(self, data, prompt_template, make_sentence, result_key,
                               fetch_results, pause_seconds, verbose=False):
        """Run one Batch API job per slice of ``self.concurrent`` groups and collect the results.

        Every sentence of a slice is one request whose ``custom_id`` is its
        position in the flattened slice; ``fetch_results(batch_id)`` must return
        ``[{"custom_id", "text"}]`` and each ``text`` is stored under ``result_key``.
        Input data is not mutated; new sentence dicts come from ``make_sentence``.
        """
        recognized_data = []
        for k in tqdm(range(0, len(data), self.concurrent)):
            groups = [[make_sentence(sentence) for sentence in group] for group in data[k:k + self.concurrent]]
            # ``flat`` shares its dicts with ``groups``, so assigning into it fills ``groups``.
            flat = [sentence for group in groups for sentence in group]
            if flat:  # nothing to submit for a slice of empty groups
                self.__store_batch(prompt_template, flat)
                if verbose:
                    print("Batch stored")
                batch_obj = self.__gpt_batch_call()
                if verbose:
                    print("Batch called, waiting for results...")
                results = fetch_results(batch_obj.id)
                if verbose:
                    print("Results received")
                self.__assign_results(flat, results, result_key)
                time.sleep(pause_seconds)
            recognized_data.extend(groups)
        return recognized_data

    def __store_batch(self, prompt_template, sentences):
        """Reset the batch input file and write one request per sentence (custom_id = its index)."""
        self.__empty_jsonl_contents()
        for i, sentence in enumerate(sentences):
            self.__gpt_batch_store(
                prompt_template.replace(self.PROMPT_PLACEHOLDER, sentence["sentence"]), str(i)
            )

    @staticmethod
    def __assign_results(sentences, results, key):
        """Store each result's ``text`` under ``key`` of the sentence its ``custom_id`` indexes."""
        for result in results:
            try:
                index = int(result["custom_id"])
            except (KeyError, TypeError, ValueError):
                print("Invalid custom_id: ", result.get("custom_id"))
                continue
            if 0 <= index < len(sentences):
                sentences[index][key] = result["text"]
            else:
                print("custom_id out of range: ", index)

    def __find_sentence_index(self, sentences, target_str, start_index=0):
        """Return the index of the first sentence at/after ``start_index`` containing ``target_str``, else -1."""
        return next(
            (i for i in range(start_index, len(sentences)) if target_str in sentences[i]),
            -1,
        )

    # ------------------------------------------------------------------
    # OpenAI access
    # ------------------------------------------------------------------

    @token_count_decorator
    def __chatgpt_function(self, content):
        """Send ``content`` as the user message and return the reply text.

        API errors are retried indefinitely with exponential backoff (capped at
        60 s). Authentication errors are re-raised because retrying cannot fix them.
        """
        client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        delay = 1
        while True:
            try:
                chat_completion = client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": self.SYSTEM_PROMPT},
                        {"role": "user", "content": content},
                    ],
                    model=self.MODEL,
                )
                return chat_completion.choices[0].message.content
            except openai.AuthenticationError:
                raise
            except openai.APIError as error:
                print(error)
                time.sleep(delay)
                delay = min(delay * 2, 60)

    @token_count_decorator
    def __gpt_batch_store(self, content, index):
        """Append one chat-completions request for ``content`` (custom_id ``index``) to the batch input file."""
        prompt_unit = {
            "custom_id": index,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": self.MODEL,
                "messages": [
                    {"role": "system", "content": self.SYSTEM_PROMPT},
                    {"role": "user", "content": content},
                ],
                "max_tokens": self.BATCH_MAX_TOKENS,
            },
        }
        with open(self.batch_input_path, "a", encoding="utf-8") as file:
            # Serialize the request as one JSON line and append it.
            file.write(json.dumps(prompt_unit) + "\n")

    def __gpt_batch_call(self):
        """Upload the batch input file and create a 24h chat-completions batch; return the batch object."""
        client = OpenAI()
        with open(self.batch_input_path, "rb") as batch_file:
            batch_input_file = client.files.create(
                file=batch_file,
                purpose="batch",
            )
        return client.batches.create(
            input_file_id=batch_input_file.id,
            endpoint="/v1/chat/completions",
            completion_window="24h",
        )

    def __get_batch_result(self, batch_id, parts_num):
        """Wait for the batch and parse each reply into ``[{"word", "label"}]``.

        Only reply lines with exactly ``parts_num`` ``|``-separated fields are kept.
        Returns ``[{"custom_id", "text"}]``; ``[]`` if the batch did not complete.
        """
        return_results = []
        for record in self.__wait_for_batch_records(batch_id):
            content = self.__response_content(record)
            if content is None:
                print("No usable response for custom_id: ", record.get("custom_id"))
                continue
            text = []
            for line in content.split("\n"):
                parts = line.split("|")
                if len(parts) == parts_num:
                    text.append({
                        "word": parts[0].strip(),
                        "label": parts[1].strip(),
                    })
            return_results.append({
                "custom_id": record["custom_id"],
                "text": text,
            })
        return return_results

    def __get_batch_result_2(self, batch_id):
        """Wait for the batch and parse each reply as JSON (``{}`` when it is not valid JSON).

        Returns ``[{"custom_id", "text"}]``; ``[]`` if the batch did not complete.
        """
        return_results = []
        for record in self.__wait_for_batch_records(batch_id):
            content = self.__response_content(record)
            if content is None:
                print("No usable response for custom_id: ", record.get("custom_id"))
                continue
            return_results.append({
                "custom_id": record["custom_id"],
                "text": json.loads(content) if is_json(content) else {},
            })
        return return_results

    def __wait_for_batch_records(self, batch_id):
        """Poll the batch until it finishes and return its output records as dicts.

        The raw output is also saved to ``batch_output_path``. Returns ``[]``
        (after printing the status) when the batch ends in a terminal
        non-completed state or has no output file.
        """
        client = OpenAI()
        while True:
            batch = client.batches.retrieve(batch_id)
            if batch.status == "completed":
                break
            if batch.status in self._TERMINAL_BATCH_STATUSES:
                print(self._TERMINAL_BATCH_STATUSES[batch.status])
                return []
            time.sleep(self.BATCH_POLL_INTERVAL)

        if batch.output_file_id is None:
            # A completed batch whose requests all failed has no output file.
            print("Batch completed without an output file")
            return []
        raw = client.files.content(batch.output_file_id).content
        with open(self.batch_output_path, "wb") as file:
            file.write(raw)
        return [json.loads(line) for line in raw.decode("utf-8").splitlines() if line.strip()]

    @staticmethod
    def __response_content(record):
        """Return the assistant message text of one batch output record, or None if unavailable."""
        response = record.get("response") or {}
        try:
            return response["body"]["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            # Failed requests carry an "error" body instead of "choices".
            return None

    def __empty_jsonl_contents(self):
        """Truncate the batch input file if it already exists."""
        if os.path.exists(self.batch_input_path):
            with open(self.batch_input_path, 'w') as file:
                file.write('')
